{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/daiyuxin/anaconda3/envs/mindspore/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.utils import http_get\n",
    "\n",
    "# NLPCC 2017 summarization corpus (read below as one JSON record per text line).\n",
    "url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'\n",
    "# Fetch into the current working directory; the returned value is used below as the dataset file path.\n",
    "download_dir = './'\n",
    "path = http_get(url, download_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50000"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from mindspore.dataset import TextFileDataset\n",
    "\n",
    "# Read the downloaded file in its original line order (shuffle disabled).\n",
    "dataset = TextFileDataset(\n",
    "    str(path),\n",
    "    shuffle=False,\n",
    ")\n",
    "# Display the number of rows read.\n",
    "dataset.get_dataset_size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 90% / 10% train-test split without randomization.\n",
    "split_ratios = [0.9, 0.1]\n",
    "train_dataset, test_dataset = dataset.split(split_ratios, randomize=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
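    "# Each field encoded on its own:\n",
    "#   article: [CLS] xxxxx [SEP]\n",
    "#   summary: [CLS] xxxxx [SEP]\n",
    "# Encoded together as a BERT text pair (what process_dataset below does):\n",
    "#   [CLS] article [SEP] summary [SEP]"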
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "\n",
    "def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):\n",
    "    def read_map(text):\n",
    "        data = json.loads(text.tobytes())\n",
    "        return np.array(data['article']), np.array(data['summarization'])\n",
    "\n",
    "    def merge_and_pad(article, summary):\n",
    "        tokenized = tokenizer(text=article, text_pair=summary,\n",
    "                              padding='max_length', truncation='only_first', max_length=max_seq_len)\n",
    "        return tokenized[0], tokenized[0]\n",
    "    \n",
    "    dataset = dataset.map(read_map, 'text', ['article', 'summary'])\n",
    "    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])\n",
    "\n",
    "    dataset = dataset.batch(batch_size)\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(batch_size)\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/daiyuxin/anaconda3/envs/mindspore/lib/python3.9/site-packages/urllib3/connectionpool.py:1100: InsecureRequestWarning: Unverified HTTPS request is being made to host 'modelscope.cn'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n",
      "  warnings.warn(\n",
      "/home/daiyuxin/anaconda3/envs/mindspore/lib/python3.9/site-packages/urllib3/connectionpool.py:1100: InsecureRequestWarning: Unverified HTTPS request is being made to host 'modelscope.cn'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n",
      "  warnings.warn(\n",
      "/home/daiyuxin/anaconda3/envs/mindspore/lib/python3.9/site-packages/urllib3/connectionpool.py:1100: InsecureRequestWarning: Unverified HTTPS request is being made to host 'modelscope.cn'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n",
      "  warnings.warn(\n",
      "/home/daiyuxin/anaconda3/envs/mindspore/lib/python3.9/site-packages/urllib3/connectionpool.py:1100: InsecureRequestWarning: Unverified HTTPS request is being made to host 'modelscope.cn'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import BertTokenizer\n",
    "\n",
    "# Load the tokenizer published under the 'bert-base-chinese' name.\n",
    "pretrained_name = 'bert-base-chinese'\n",
    "tokenizer = BertTokenizer.from_pretrained(pretrained_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Tokenize and batch the training split (4 samples per batch).\n",
    "# NOTE: this rebinds train_dataset, so re-running the cell would feed already-processed data back in.\n",
    "train_dataset = process_dataset(\n",
    "    train_dataset,\n",
    "    tokenizer,\n",
    "    batch_size=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Tensor(shape=[4, 1024], dtype=Int64, value=\n",
       " [[ 101, 1724, 3862 ...    0,    0,    0],\n",
       "  [ 101,  704, 3173 ...    0,    0,    0],\n",
       "  [ 101, 1079, 2159 ... 1745, 8021,  102],\n",
       "  [ 101, 1355, 2357 ...    0,    0,    0]]),\n",
       " Tensor(shape=[4, 1024], dtype=Int64, value=\n",
       " [[ 101, 1724, 3862 ...    0,    0,    0],\n",
       "  [ 101,  704, 3173 ...    0,    0,    0],\n",
       "  [ 101, 1079, 2159 ... 1745, 8021,  102],\n",
       "  [ 101, 1355, 2357 ...    0,    0,    0]])]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pull one batch to sanity-check the pipeline output: two columns, input_ids and labels.\n",
    "next(train_dataset.create_tuple_iterator())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21128"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the tokenizer vocabulary; the same value is passed to GPT2Config as vocab_size below.\n",
    "len(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import ops\n",
    "from mindnlp.transformers import GPT2LMHeadModel\n",
    "\n",
    "class GPT2ForSummarization(GPT2LMHeadModel):\n",
    "    \"\"\"GPT-2 LM-head model whose forward pass returns the next-token loss.\"\"\"\n",
    "\n",
    "    def construct(self, input_ids=None, attention_mask=None, labels=None):\n",
    "        \"\"\"Run the language model and return the shifted cross-entropy loss.\n",
    "\n",
    "        Args:\n",
    "            input_ids: token ids forwarded to the parent model.\n",
    "            attention_mask: optional mask forwarded to the parent model.\n",
    "            labels: target token ids; must be provided, since they are sliced below.\n",
    "\n",
    "        Returns:\n",
    "            The cross-entropy loss, ignoring positions whose label equals the\n",
    "            pad id of the notebook-level `tokenizer`.\n",
    "        \"\"\"\n",
    "        lm_outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)\n",
    "        # Position t predicts token t + 1: drop the last logit and the first label.\n",
    "        pred_logits = lm_outputs.logits[..., :-1, :]\n",
    "        target_ids = labels[..., 1:]\n",
    "        # Flatten to (tokens, vocab) and (tokens,) for the loss function.\n",
    "        flat_logits = pred_logits.view(-1, pred_logits.shape[-1])\n",
    "        flat_targets = target_ids.view(-1)\n",
    "        # NOTE: `tokenizer` is the global defined in an earlier cell, not an attribute of the model.\n",
    "        return ops.cross_entropy(flat_logits, flat_targets, ignore_index=tokenizer.pad_token_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import ops\n",
    "from mindspore.nn.learning_rate_schedule import LearningRateSchedule\n",
    "\n",
    "class LinearWithWarmUp(LearningRateSchedule):\n",
    "    \"\"\"Learning-rate schedule with linear warm-up followed by linear decay.\n",
    "\n",
    "    The rate rises from 0 to `learning_rate` over `num_warmup_steps`, then\n",
    "    falls linearly so that it reaches 0 at `num_training_steps` (never below 0).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):\n",
    "        \"\"\"Store the peak rate and the two step counts that shape the schedule.\"\"\"\n",
    "        super().__init__()\n",
    "        self.learning_rate = learning_rate\n",
    "        self.num_warmup_steps = num_warmup_steps\n",
    "        self.num_training_steps = num_training_steps\n",
    "\n",
    "    def construct(self, global_step):\n",
    "        \"\"\"Return the learning rate to use at `global_step`.\"\"\"\n",
    "        if global_step < self.num_warmup_steps:\n",
    "            # Warm-up phase; max(1, ...) guards against a zero-length warm-up.\n",
    "            warmup_span = float(max(1, self.num_warmup_steps))\n",
    "            return global_step / warmup_span * self.learning_rate\n",
    "        # Decay phase: fraction of the post-warm-up steps still remaining, clamped at 0.\n",
    "        decay_span = max(1, self.num_training_steps - self.num_warmup_steps)\n",
    "        remaining = (self.num_training_steps - global_step) / decay_span\n",
    "        return ops.maximum(0.0, remaining) * self.learning_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimization hyper-parameters.\n",
    "num_epochs = 5          # passes over the training set\n",
    "warmup_steps = 2000     # steps of linear learning-rate warm-up\n",
    "learning_rate = 1.5e-4  # peak learning rate reached after warm-up\n",
    "\n",
    "# Total optimizer steps = batches per epoch times epochs; this is the decay horizon of the schedule.\n",
    "steps_per_epoch = train_dataset.get_dataset_size()\n",
    "num_training_steps = num_epochs * steps_per_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from mindspore import nn\n",
    "from mindnlp.transformers import GPT2Config, GPT2LMHeadModel\n",
    "\n",
    "# Size the GPT-2 vocabulary to the BERT tokenizer; the model is built from the config alone\n",
    "# (no pretrained weights are loaded here).\n",
    "config = GPT2Config(vocab_size=len(tokenizer))\n",
    "model = GPT2ForSummarization(config)\n",
    "\n",
    "# Warm-up / linear-decay schedule feeding the AdamWeightDecay optimizer.\n",
    "lr_scheduler = LinearWithWarmUp(\n",
    "    learning_rate=learning_rate,\n",
    "    num_warmup_steps=warmup_steps,\n",
    "    num_training_steps=num_training_steps,\n",
    ")\n",
    "optimizer = nn.AdamWeightDecay(\n",
    "    model.trainable_params(),\n",
    "    learning_rate=lr_scheduler,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of model parameters: 102068736\n"
     ]
    }
   ],
   "source": [
    "# Report the number of model parameters.\n",
    "print(f'number of model parameters: {model.num_parameters()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Trainer will use 'StaticLossScaler' with `scale_value=2 ** 10` when `loss_scaler` is None.\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.engine import Trainer\n",
    "from mindnlp.engine.callbacks import CheckpointCallback\n",
    "\n",
    "# Write a checkpoint named 'gpt2_summarization_epoch_<n>.ckpt' under ./checkpoint after every epoch,\n",
    "# keeping at most two files on disk.\n",
    "ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',\n",
    "                                epochs=1, keep_checkpoint_max=2)\n",
    "\n",
    "# Use num_epochs rather than a second literal: the LR schedule length (num_training_steps)\n",
    "# was derived from num_epochs, so the trainer must run exactly that many epochs.\n",
    "trainer = Trainer(network=model, train_dataset=train_dataset,\n",
    "                  epochs=num_epochs, optimizer=optimizer, callbacks=ckpoint_cb)\n",
    "# Enable mixed precision at level O1.\n",
    "trainer.set_amp(level='O1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The train will start from the checkpoint saved in 'checkpoint'.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0:   0%|                                                                                                         | 0/11250 [00:00<?, ?it/s][ERROR] CORE(1809837,7f95ef9896c0,python):2023-11-26-01:46:08.184.627 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1809837/729418715.py]\n",
      "[ERROR] CORE(1809837,7f95ef9896c0,python):2023-11-26-01:46:08.184.657 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1809837/729418715.py]\n",
      "[ERROR] CORE(1809837,7f95ef9896c0,python):2023-11-26-01:46:08.184.892 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1809837/729418715.py]\n",
      "[ERROR] CORE(1809837,7f95ef9896c0,python):2023-11-26-01:46:08.184.904 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1809837/729418715.py]\n",
      "Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████| 11250/11250 [1:23:37<00:00,  2.24it/s, loss=4.937714]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint: 'gpt2_summarization_epoch_0.ckpt' has been saved in epoch: 0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████| 11250/11250 [1:23:40<00:00,  2.24it/s, loss=4.7891397]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint: 'gpt2_summarization_epoch_1.ckpt' has been saved in epoch: 1.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████| 11250/11250 [1:23:44<00:00,  2.24it/s, loss=4.789195]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The maximum number of stored checkpoints has been reached.\n",
      "Checkpoint: 'gpt2_summarization_epoch_2.ckpt' has been saved in epoch: 2.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████| 11250/11250 [1:23:36<00:00,  2.24it/s, loss=4.789144]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The maximum number of stored checkpoints has been reached.\n",
      "Checkpoint: 'gpt2_summarization_epoch_3.ckpt' has been saved in epoch: 3.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|███████████████████████████████████████████████████████████████████████████| 11250/11250 [1:23:36<00:00,  2.24it/s, loss=4.7891893]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The maximum number of stored checkpoints has been reached.\n",
      "Checkpoint: 'gpt2_summarization_epoch_4.ckpt' has been saved in epoch: 4.\n"
     ]
    }
   ],
   "source": [
    "# Start training; 'labels' is the target column emitted by process_dataset.\n",
    "trainer.run(tgt_columns=\"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):\n",
    "    def read_map(text):\n",
    "        data = json.loads(text.tobytes())\n",
    "        return np.array(data['article']), np.array(data['summarization'])\n",
    "\n",
    "    def pad(article):\n",
    "        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)\n",
    "        return tokenized[0]\n",
    "\n",
    "    dataset = dataset.map(read_map, 'text', ['article', 'summary'])\n",
    "    dataset = dataset.map(pad, 'article', ['input_ids'])\n",
    "    \n",
    "    dataset = dataset.batch(batch_size)\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One article per batch: process_test_dataset passes no padding option to the tokenizer.\n",
    "batched_test_dataset = process_test_dataset(\n",
    "    test_dataset,\n",
    "    tokenizer,\n",
    "    batch_size=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[array([[ 101, 3173, 1290, 5381, 6205, 2128,  128, 3299,  129, 3189, 4510,\n",
      "        8020, 6381, 5442, 6948, 1132,  840, 8021, 6381, 5442,  129, 3189,\n",
      "         678, 1286,  794, 7362, 6205, 4209,  689, 1265, 2339, 7415, 1730,\n",
      "         749, 6237, 1168, 8024,  754,  128, 3299,  127, 3189, 7506, 1814,\n",
      "        3433, 3409, 1790,  759,  678,  752, 3125,  704, 6158, 1737, 4638,\n",
      "         125, 1399, 4771, 2339, 2347, 4802, 6371, 1059, 6956, 3647,  767,\n",
      "         511,  128, 3299,  127, 3189, 1119, 3247,  125, 3198, 8123, 1146,\n",
      "        8024, 7362, 6205, 4689, 7506, 1814, 3433, 3409, 1790, 4209, 4771,\n",
      "        2963, 6822,  676, 7339, 1762, 4209, 4771,  122, 1384,  759, 1298,\n",
      "        7023, 1277,  671, 2339,  868, 7481, 6822, 6121,  868,  689, 3198,\n",
      "        1355, 4495,  671, 6629, 4209,  680, 4482, 3172, 4960, 1139,  752,\n",
      "        3125, 8024,  125, 1399, 4771, 2339, 6158, 1737,  511, 3131, 3001,\n",
      "         782, 1447,  754,  127, 3189,  678, 1286, 1355, 4385,  671, 1399,\n",
      "        6158, 1737, 4771, 2339, 8024, 5307, 1059, 1213, 2843, 3131, 3187,\n",
      "        3126, 3647,  767,  511, 5307, 3131, 2844,  782, 1447, 1059, 1213,\n",
      "        3131, 3001, 8024, 3297, 1400,  671, 1399, 6878, 7410, 4771, 2339,\n",
      "        6890,  860, 2347,  754,  129, 3189, 8110, 3198, 8216, 1146, 2823,\n",
      "        1168,  511, 5635, 3634, 8024,  125, 1399, 6878, 7410, 4771, 2339,\n",
      "        6890,  860, 1059, 6956, 2823, 1168, 1285,  759, 8024, 3131, 3001,\n",
      "        2339,  868, 5310, 3338,  511, 4771, 3175, 3633, 1762,  976, 1587,\n",
      "        1400, 2339,  868,  511,  102]], dtype=int64), array(['渭南韩城县桑树坪井下事故中被困4名矿工全部死亡，遗体今日中午均被找到。'], dtype='<U35')]\n"
     ]
    }
   ],
   "source": [
    "# Peek at the first test batch as NumPy arrays: article token ids and the reference summary.\n",
    "first_test_batch = next(batched_test_dataset.create_tuple_iterator(output_numpy=True))\n",
    "print(first_test_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/daiyuxin/anaconda3/envs/mindspore/lib/python3.9/site-packages/urllib3/connectionpool.py:1100: InsecureRequestWarning: Unverified HTTPS request is being made to host 'modelscope.cn'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n",
      "  warnings.warn(\n",
      "Generation config file not found, using a generation config created from the model config.\n"
     ]
    }
   ],
   "source": [
    "# Reload the last saved epoch into a plain GPT2LMHeadModel for generation\n",
    "# (GPT2ForSummarization.construct returns only the loss).\n",
    "model = GPT2LMHeadModel.from_pretrained(\n",
    "    './checkpoint/gpt2_summarization_epoch_4.ckpt',\n",
    "    config=config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.set_train(False)\n",
    "\n",
    "# The GPT2Config used here was built with only vocab_size, so no [SEP] id was ever stored on\n",
    "# model.config (its sep_token_id is presumably None). Take the stop token from the tokenizer,\n",
    "# so generation ends at [SEP] instead of always running to max_new_tokens.\n",
    "eos_token_id = tokenizer.sep_token_id\n",
    "model.config.eos_token_id = eos_token_id\n",
    "\n",
    "num_examples = 10  # number of test articles to summarize\n",
    "test_batches = batched_test_dataset.create_tuple_iterator()\n",
    "for example_idx, (input_ids, raw_summary) in enumerate(test_batches, start=1):\n",
    "    # Beam search with bigram blocking; eos/pad ids are passed explicitly so they do not\n",
    "    # depend on how the generation config was derived from the model config.\n",
    "    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2,\n",
    "                                eos_token_id=eos_token_id, pad_token_id=tokenizer.pad_token_id)\n",
    "    # The decoded text contains the article prompt followed by the generated continuation.\n",
    "    output_text = tokenizer.decode(output_ids[0].tolist())\n",
    "    print(output_text)\n",
    "    if example_idx == num_examples:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
